{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CORN MLP for ordinal regression and deep learning -- cement strength dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial explains how to train a deep neural network (here: multilayer perceptron) with the CORN loss function for ordinal regression. \n",
    "\n",
    "**CORN reference:**\n",
    "\n",
    "- Xintong Shi, Wenzhi Cao, and Sebastian Raschka \n",
    "[Deep Neural Networks for Rank-Consistent Ordinal Regression Based On Conditional Probabilities](https://arxiv.org/abs/2111.08851)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0 -- Obtaining and preparing the cement_strength dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be using the cement_strength dataset from [https://github.com/gagolews/ordinal_regression_data/blob/master/cement_strength.csv](https://github.com/gagolews/ordinal_regression_data/blob/master/cement_strength.csv).\n",
    "\n",
    "First, we are going to download and prepare the dataset. This is a general procedure that is not specific to CORN.\n",
    "\n",
    "This dataset has 5 ordinal labels (1, 2, 3, 4, and 5). Note that we require labels to start at 0, which is why we subtract \"1\" from the label column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of features: 8\n",
      "Number of examples: 998\n",
      "Labels: [0 1 2 3 4]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "data_df = pd.read_csv(\"https://raw.githubusercontent.com/gagolews/ordinal_regression_data/master/cement_strength.csv\")\n",
    " \n",
    "data_df[\"response\"] = data_df[\"response\"]-1 # labels should start at 0\n",
    "\n",
    "data_labels = data_df[\"response\"]\n",
    "data_features = data_df.loc[:, [\"V1\", \"V2\", \"V3\", \"V4\", \"V5\", \"V6\", \"V7\", \"V8\"]]\n",
    "\n",
    "print('Number of features:', data_features.shape[1])\n",
    "print('Number of examples:', data_features.shape[0])\n",
    "print('Labels:', np.unique(data_labels.values))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Split into training and test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    data_features.values,\n",
    "    data_labels.values,\n",
    "    test_size=0.2,\n",
    "    random_state=1,\n",
    "    stratify=data_labels.values)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Standardize features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "sc = StandardScaler()\n",
    "X_train_std = sc.fit_transform(X_train)\n",
    "X_test_std = sc.transform(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 -- Setting up the dataset and dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we set up the data set and data loaders. This is a general procedure that is not specific to the method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training on cuda:0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.001\n",
    "num_epochs = 50\n",
    "batch_size = 128\n",
    "\n",
    "# Architecture\n",
    "NUM_CLASSES = 5\n",
    "\n",
    "# Other\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print('Training on', DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "\n",
    "    def __init__(self, feature_array, label_array, dtype=np.float32):\n",
    "    \n",
    "        # Cast features to the requested dtype (float32 by default)\n",
    "        self.features = feature_array.astype(dtype)\n",
    "        self.labels = label_array\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        inputs = self.features[index]\n",
    "        label = self.labels[index]\n",
    "        return inputs, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.labels.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input batch dimensions: torch.Size([128, 8])\n",
      "Input label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "# Wrap the standardized NumPy arrays in Dataset objects\n",
    "train_dataset = MyDataset(X_train_std, y_train)\n",
    "test_dataset = MyDataset(X_test_std, y_test)\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=batch_size,\n",
    "                          shuffle=True, # want to shuffle the dataset\n",
    "                          num_workers=0) # number processes/CPUs to use\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset,\n",
    "                         batch_size=batch_size,\n",
    "                         shuffle=False,\n",
    "                         num_workers=0)\n",
    "\n",
    "# Checking the dataset\n",
    "for inputs, labels in train_loader:  \n",
    "    print('Input batch dimensions:', inputs.shape)\n",
    "    print('Input label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 -- Equipping MLP with a CORN layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we implement a simple MLP for ordinal regression with CORN. The only CORN-specific modification is setting the number of outputs of the last layer (a fully connected layer) to the number of classes minus 1. These outputs correspond to the binary tasks used in the extended binary classification described in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, in_features, num_classes, num_hidden_1=300, num_hidden_2=300):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.my_network = torch.nn.Sequential(\n",
    "            \n",
    "            # 1st hidden layer\n",
    "            torch.nn.Linear(in_features, num_hidden_1, bias=False),\n",
    "            torch.nn.LeakyReLU(),\n",
    "            torch.nn.Dropout(0.2),\n",
    "            torch.nn.BatchNorm1d(num_hidden_1),\n",
    "            \n",
    "            # 2nd hidden layer\n",
    "            torch.nn.Linear(num_hidden_1, num_hidden_2, bias=False),\n",
    "            torch.nn.LeakyReLU(),\n",
    "            torch.nn.Dropout(0.2),\n",
    "            torch.nn.BatchNorm1d(num_hidden_2),\n",
    "            \n",
    "            ### Specify CORN layer\n",
    "            torch.nn.Linear(num_hidden_2, (num_classes-1))\n",
    "            ###--------------------------------------------------------------------###\n",
    "        )\n",
    "                \n",
    "    def forward(self, x):\n",
    "        logits = self.my_network(x)\n",
    "        return logits\n",
    "    \n",
    "    \n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = MLP(in_features=8, num_classes=NUM_CLASSES)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 -- Using the CORN loss for model training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "During training, all you need to do is use the `corn_loss` function provided by `coral_pytorch`. The loss function takes care of constructing the conditional training subsets and modeling the conditional probabilities used in the chain rule (aka general product rule)."
   ]
  },
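  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The conditional training subsets mentioned above can be illustrated in plain Python. This is a minimal sketch under stated assumptions, not the `coral_pytorch` implementation: the toy labels and the helper name `conditional_subsets` are hypothetical and only serve as an illustration. Task `k` only sees examples whose label is greater than `k - 1`, and its binary target is whether the label is greater than `k`, so later tasks are trained on fewer examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_num_classes = 5\n",
    "toy_labels = [0, 1, 2, 3, 4, 2, 0]  # hypothetical rank labels, not taken from the dataset\n",
    "\n",
    "\n",
    "def conditional_subsets(labels, num_classes):\n",
    "    # Hypothetical helper: returns one list of (example_index, binary_target)\n",
    "    # pairs for each of the num_classes - 1 binary tasks.\n",
    "    subsets = []\n",
    "    for k in range(num_classes - 1):\n",
    "        # Task k keeps examples with label > k - 1; the target is 1 if label > k.\n",
    "        task = [(i, int(y > k)) for i, y in enumerate(labels) if y > k - 1]\n",
    "        subsets.append(task)\n",
    "    return subsets\n",
    "\n",
    "\n",
    "toy_subsets = conditional_subsets(toy_labels, toy_num_classes)\n",
    "for k, task in enumerate(toy_subsets):\n",
    "    print('task', k, '| examples:', len(task), '| targets:', [t for _, t in task])"
   ]
  },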
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install coral-pytorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/050 | Batch 000/007 | Cost: 44.5581\n",
      "Epoch: 002/050 | Batch 000/007 | Cost: 37.2149\n",
      "Epoch: 003/050 | Batch 000/007 | Cost: 31.0465\n",
      "Epoch: 004/050 | Batch 000/007 | Cost: 31.5560\n",
      "Epoch: 005/050 | Batch 000/007 | Cost: 26.1428\n",
      "Epoch: 006/050 | Batch 000/007 | Cost: 24.7515\n",
      "Epoch: 007/050 | Batch 000/007 | Cost: 23.3453\n",
      "Epoch: 008/050 | Batch 000/007 | Cost: 23.9005\n",
      "Epoch: 009/050 | Batch 000/007 | Cost: 21.4760\n",
      "Epoch: 010/050 | Batch 000/007 | Cost: 22.7869\n",
      "Epoch: 011/050 | Batch 000/007 | Cost: 22.6931\n",
      "Epoch: 012/050 | Batch 000/007 | Cost: 21.0200\n",
      "Epoch: 013/050 | Batch 000/007 | Cost: 21.9232\n",
      "Epoch: 014/050 | Batch 000/007 | Cost: 19.6500\n",
      "Epoch: 015/050 | Batch 000/007 | Cost: 18.7777\n",
      "Epoch: 016/050 | Batch 000/007 | Cost: 20.1790\n",
      "Epoch: 017/050 | Batch 000/007 | Cost: 19.0979\n",
      "Epoch: 018/050 | Batch 000/007 | Cost: 17.8689\n",
      "Epoch: 019/050 | Batch 000/007 | Cost: 18.2110\n",
      "Epoch: 020/050 | Batch 000/007 | Cost: 16.4955\n",
      "Epoch: 021/050 | Batch 000/007 | Cost: 17.3134\n",
      "Epoch: 022/050 | Batch 000/007 | Cost: 15.9069\n",
      "Epoch: 023/050 | Batch 000/007 | Cost: 15.6398\n",
      "Epoch: 024/050 | Batch 000/007 | Cost: 14.6094\n",
      "Epoch: 025/050 | Batch 000/007 | Cost: 14.0235\n",
      "Epoch: 026/050 | Batch 000/007 | Cost: 13.4980\n",
      "Epoch: 027/050 | Batch 000/007 | Cost: 15.2777\n",
      "Epoch: 028/050 | Batch 000/007 | Cost: 12.3507\n",
      "Epoch: 029/050 | Batch 000/007 | Cost: 12.4901\n",
      "Epoch: 030/050 | Batch 000/007 | Cost: 14.0257\n",
      "Epoch: 031/050 | Batch 000/007 | Cost: 13.6611\n",
      "Epoch: 032/050 | Batch 000/007 | Cost: 12.7932\n",
      "Epoch: 033/050 | Batch 000/007 | Cost: 12.7365\n",
      "Epoch: 034/050 | Batch 000/007 | Cost: 12.3747\n",
      "Epoch: 035/050 | Batch 000/007 | Cost: 11.5628\n",
      "Epoch: 036/050 | Batch 000/007 | Cost: 15.1797\n",
      "Epoch: 037/050 | Batch 000/007 | Cost: 12.0188\n",
      "Epoch: 038/050 | Batch 000/007 | Cost: 11.1485\n",
      "Epoch: 039/050 | Batch 000/007 | Cost: 13.3310\n",
      "Epoch: 040/050 | Batch 000/007 | Cost: 11.3578\n",
      "Epoch: 041/050 | Batch 000/007 | Cost: 16.6671\n",
      "Epoch: 042/050 | Batch 000/007 | Cost: 17.0777\n",
      "Epoch: 043/050 | Batch 000/007 | Cost: 13.1835\n",
      "Epoch: 044/050 | Batch 000/007 | Cost: 13.2365\n",
      "Epoch: 045/050 | Batch 000/007 | Cost: 10.8442\n",
      "Epoch: 046/050 | Batch 000/007 | Cost: 9.3964\n",
      "Epoch: 047/050 | Batch 000/007 | Cost: 12.5001\n",
      "Epoch: 048/050 | Batch 000/007 | Cost: 8.8510\n",
      "Epoch: 049/050 | Batch 000/007 | Cost: 12.1387\n",
      "Epoch: 050/050 | Batch 000/007 | Cost: 10.9765\n"
     ]
    }
   ],
   "source": [
    "from coral_pytorch.losses import corn_loss\n",
    "\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    \n",
    "    model = model.train()\n",
    "    for batch_idx, (features, class_labels) in enumerate(train_loader):\n",
    "\n",
    "        class_labels = class_labels.to(DEVICE)\n",
    "        features = features.to(DEVICE)\n",
    "        logits = model(features)\n",
    "        \n",
    "        #### CORN loss \n",
    "        loss = corn_loss(logits, class_labels, NUM_CLASSES)\n",
    "        ###--------------------------------------------------------------------###   \n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 200:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4 -- Evaluate model\n",
    "\n",
    "Finally, after model training, we can evaluate the performance of the model, for example, via the mean absolute error (MAE) and mean squared error (MSE).\n",
    "\n",
    "For this, we are going to use the `corn_label_from_logits` utility function from `coral_pytorch` to convert the logits back to the original labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from coral_pytorch.dataset import corn_label_from_logits\n",
    "\n",
    "\n",
    "def compute_mae_and_mse(model, data_loader, device):\n",
    "\n",
    "    with torch.no_grad():\n",
    "    \n",
    "        mae, mse, num_examples = 0., 0., 0\n",
    "\n",
    "        for i, (features, targets) in enumerate(data_loader):\n",
    "\n",
    "            features = features.to(device)\n",
    "            targets = targets.float().to(device)\n",
    "\n",
    "            logits = model(features)\n",
    "            predicted_labels = corn_label_from_logits(logits).float()\n",
    "\n",
    "            num_examples += targets.size(0)\n",
    "            mae += torch.sum(torch.abs(predicted_labels - targets))\n",
    "            mse += torch.sum((predicted_labels - targets)**2)\n",
    "\n",
    "        mae = mae / num_examples\n",
    "        mse = mse / num_examples\n",
    "        return mae, mse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mae, train_mse = compute_mae_and_mse(model, train_loader, DEVICE)\n",
    "test_mae, test_mse = compute_mae_and_mse(model, test_loader, DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean absolute error (train/test): 0.19 | 0.32\n",
      "Mean squared error (train/test): 0.20 | 0.35\n"
     ]
    }
   ],
   "source": [
    "print(f'Mean absolute error (train/test): {train_mae:.2f} | {test_mae:.2f}')\n",
    "print(f'Mean squared error (train/test): {train_mse:.2f} | {test_mse:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the cement_strength labels are ordered, the MAE and MSE are meaningful here: a test MAE of 0.32 means that the predicted rank is off by about a third of a rank position on average."
   ]
  },
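  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two error measures concrete, here is a small plain-Python sketch on made-up rank labels (the `toy_` values are hypothetical and unrelated to the model above). It shows that the MSE penalizes a prediction that is two ranks off more heavily than the MAE does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_targets = [0, 1, 2, 3, 4]      # hypothetical true rank labels\n",
    "toy_predictions = [0, 2, 2, 1, 4]  # hypothetical predicted rank labels\n",
    "\n",
    "# Absolute rank difference per example: [0, 1, 0, 2, 0]\n",
    "toy_abs_errors = [abs(p - t) for p, t in zip(toy_predictions, toy_targets)]\n",
    "\n",
    "toy_mae = sum(toy_abs_errors) / len(toy_abs_errors)\n",
    "toy_mse = sum(e ** 2 for e in toy_abs_errors) / len(toy_abs_errors)\n",
    "\n",
    "print('MAE:', toy_mae)\n",
    "print('MSE:', toy_mse)"
   ]
  },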
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 -- Rank probabilities from logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To obtain the rank probabilities from the logits, apply the sigmoid function to get the conditional probability for each task, then compute the rank probabilities as the cumulative product of these conditional probabilities (the chain rule for probabilities). Note that `corn_label_from_logits`, which we used above, also does this internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[9.8258e-01, 9.8255e-01, 9.8059e-01, 9.5894e-01],\n",
      "        [9.5521e-01, 8.4898e-02, 5.0589e-03, 1.8228e-03],\n",
      "        [9.9904e-01, 9.3805e-01, 7.2734e-02, 1.1431e-02],\n",
      "        [9.9465e-01, 3.3964e-02, 1.0907e-03, 4.0408e-04],\n",
      "        [8.2752e-01, 4.1369e-03, 4.1199e-05, 4.0809e-05],\n",
      "        [9.9777e-01, 9.9021e-01, 5.1255e-01, 1.6108e-01],\n",
      "        [9.9967e-01, 7.1187e-01, 5.8335e-03, 2.9053e-05],\n",
      "        [2.0897e-01, 5.6346e-05, 2.1329e-05, 1.9403e-05],\n",
      "        [9.9859e-01, 6.4831e-01, 2.3655e-02, 8.0152e-03],\n",
      "        [7.4509e-01, 2.3040e-03, 1.1488e-04, 8.7757e-05],\n",
      "        [9.9034e-01, 9.7828e-02, 1.0961e-02, 1.5772e-03],\n",
      "        [7.8535e-03, 4.3403e-05, 7.7310e-06, 7.2784e-06],\n",
      "        [3.9570e-02, 1.0581e-04, 4.5490e-06, 2.2578e-06],\n",
      "        [6.2016e-01, 3.0076e-02, 3.6441e-04, 3.0685e-04],\n",
      "        [9.9347e-01, 9.8310e-01, 3.7508e-01, 7.8585e-03],\n",
      "        [3.0705e-01, 3.1166e-04, 6.0899e-05, 5.3413e-05],\n",
      "        [9.9829e-01, 8.0362e-01, 2.0864e-02, 1.3445e-04],\n",
      "        [9.9912e-01, 9.9793e-01, 9.9118e-01, 3.2078e-02],\n",
      "        [8.3965e-01, 1.2979e-01, 2.0052e-03, 1.2908e-04],\n",
      "        [9.9663e-01, 8.7294e-01, 9.0855e-02, 4.4399e-03],\n",
      "        [1.1083e-02, 2.1253e-05, 1.0237e-05, 8.1861e-06],\n",
      "        [9.9885e-01, 9.9807e-01, 7.8067e-01, 1.5610e-01],\n",
      "        [9.9931e-01, 7.8717e-01, 5.2899e-01, 2.8850e-03],\n",
      "        [9.9917e-01, 9.9807e-01, 2.2598e-01, 9.3320e-04],\n",
      "        [9.9742e-01, 9.2505e-01, 1.3515e-02, 9.4619e-05],\n",
      "        [1.3924e-05, 6.6898e-08, 4.3272e-08, 4.3070e-08],\n",
      "        [9.9928e-01, 8.2273e-01, 2.8413e-01, 2.4039e-03],\n",
      "        [9.9870e-01, 7.0082e-01, 4.2848e-03, 1.7001e-04],\n",
      "        [9.8960e-01, 9.8807e-01, 9.8793e-01, 6.6704e-01],\n",
      "        [9.9391e-01, 4.9507e-01, 9.4512e-02, 1.7841e-02]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# `features` still holds the last mini-batch from the training loop above\n",
    "with torch.no_grad():\n",
    "    logits = model(features)\n",
    "    probas = torch.sigmoid(logits)\n",
    "    probas = torch.cumprod(probas, dim=1)\n",
    "    print(probas)"
   ]
  },
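  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same computation can be traced by hand for a single example. The sketch below uses plain Python and hypothetical logits (`toy_logits` is made up for illustration). The final step, counting how many rank probabilities exceed 0.5 to obtain the predicted label, reflects our understanding of how `corn_label_from_logits` derives labels and should be read as an assumption rather than a specification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "toy_logits = [2.0, 0.5, -1.0, -3.0]  # hypothetical logits for one example (5 classes, 4 tasks)\n",
    "\n",
    "# Sigmoid turns each logit into a conditional probability P(y > k | y > k - 1)\n",
    "toy_cond_probas = [1.0 / (1.0 + math.exp(-z)) for z in toy_logits]\n",
    "\n",
    "# Chain rule: the cumulative product gives the unconditional P(y > k)\n",
    "toy_rank_probas = []\n",
    "running_product = 1.0\n",
    "for p in toy_cond_probas:\n",
    "    running_product *= p\n",
    "    toy_rank_probas.append(running_product)\n",
    "\n",
    "# Assumed decoding rule: predicted label = number of rank probabilities above 0.5\n",
    "toy_predicted_label = sum(p > 0.5 for p in toy_rank_probas)\n",
    "\n",
    "print('Rank probabilities:', [round(p, 4) for p in toy_rank_probas])\n",
    "print('Predicted label:', toy_predicted_label)"
   ]
  },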
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
